# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os

import numpy as np

import mindspore as ms
import mindspore.communication.management as distributedTool
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.nn import Cell
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.train import Model
from mindspore.train.callback import Callback

np.set_printoptions(threshold=np.inf)
device_num = 2
device_id = int(os.getenv('DEVICE_ID', '0'))  # fall back to device 0 if unset
rank_id = 0
embed = 128
classes = 32
batch_size = 32*2
MatmulParamShape = (classes, embed)
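
# Test configuration: two devices train the same single-layer classifier and
# the per-epoch loss curves are compared across parallel modes. batch_size is
# the global batch (32 per device when device_num == 2), embed is the input
# feature width, and classes is the number of output classes produced by the
# MatMul layer.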


def setup_module():
    global device_num
    global rank_id
    np.random.seed(0)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_context(enable_task_sink=True,
                        device_id=device_id)
    context.set_context(enable_ir_fusion=True)
    context.set_context(enable_loop_sink=False)
    distributedTool.init()
    rank_id = distributedTool.get_rank()
    device_num = distributedTool.get_group_size()
    context.set_auto_parallel_context(device_num=device_num,
                                      global_rank=rank_id)
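
    # Note: this setup assumes one process per device, with DEVICE_ID set in
    # the environment and the Ascend collective runtime configured beforehand
    # (e.g. via a rank table file); distributedTool.init() then joins the
    # communication group before any parallel operator runs.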


def teardown_module():
    distributedTool.release()


class DataGenerator:
    """Generates deterministic inputs/labels and their per-rank shards."""

    def get_parallel_blocks(self, input_, strategy):
        """Split input_ axis by axis into the per-device blocks given by strategy."""
        blocks = [input_]
        i = 0
        for stra in strategy:
            temp = []
            while len(blocks) > 0:
                block = blocks.pop(0)
                temp.extend(np.split(block, stra, axis=i))
            blocks.extend(temp)
            i += 1
        return blocks
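
    # Example (a sketch, assuming device_num == 2): splitting a (4, 6) array
    # with strategy [2, 1] cuts axis 0 into two pieces and leaves axis 1
    # whole, yielding two (2, 6) blocks; blocks[rank_id] is then the shard
    # owned by this rank.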

    def generate_data(self, shape):
        """Return a deterministic array of the given shape with values in [0, 1)."""
        size = int(np.prod(shape))
        num_range = min(size, 1000)
        data = (np.arange(0, size) % num_range)/num_range
        data = np.reshape(data, shape)
        return data

    def input_data(self, shape):
        """Return (full input, this rank's shard), split along axis 0."""
        data = (self.generate_data(shape) * 0.1).astype(np.float32)
        stra = [1]*len(shape)
        stra[0] = device_num
        datas = self.get_parallel_blocks(data, stra)
        return Tensor(data), Tensor(datas[rank_id])

    def label_data(self, shape, num_classes):
        """Return (full labels, this rank's shard) with class ids below num_classes."""
        data = (self.generate_data(shape) * (num_classes - 1)).astype(np.int32)
        stra = [1]*len(shape)
        stra[0] = device_num
        datas = self.get_parallel_blocks(data, stra)
        return Tensor(data), Tensor(datas[rank_id])
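
# Both helpers return (full tensor, per-rank shard): the full tensor feeds
# the single-device baseline while the shard feeds the parallel runs, so
# every mode consumes exactly the same global batch.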


class Dataset:
    """A fixed in-memory batch served as an iterable dataset."""

    def __init__(self, predict, label, length=1, input_num=2):
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length
        self.input_num = input_num

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        if self.input_num == 2:
            return self.predict, self.label
        else:
            return self.predict,

    def reset(self):
        self.index = 0

    def get_dataset_size(self):
        return self.length

    def get_repeat_count(self):
        return self.length
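
    # Only the minimal iterator interface that Model.train touches when
    # dataset_sink_mode=False is provided: iteration plus reset(),
    # get_dataset_size() and get_repeat_count(). There is no real data
    # pipeline, so every epoch replays the same fixed batch.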


class ModelCallback(Callback):
    """Records the network output (the loss) at the end of every epoch."""

    def __init__(self):
        super(ModelCallback, self).__init__()
        self.loss_list = []

    def epoch_end(self, run_context):
        cb_params = run_context.original_args()
        result = cb_params.net_outputs
        self.loss_list.append(result.asnumpy().mean())
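
    # One loss value is recorded per epoch, so each training helper below
    # returns a loss curve of length epoch_size that can be compared
    # element-wise across parallel modes.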


class SoftmaxCrossEntropyExpand(Cell):
    def __init__(self, sparse=False, stra_list=None):
        super(SoftmaxCrossEntropyExpand, self).__init__()
        # avoid a mutable default argument; None means "no explicit strategies"
        if stra_list is None or len(stra_list) < 11:
            stra_list = [None] * 11
        self.exp = P.Exp().set_strategy(strategy=stra_list[0])
        self.reduce_sum = P.ReduceSum(keep_dims=True).set_strategy(strategy=stra_list[1])
        self.onehot = P.OneHot().set_strategy(strategy=stra_list[2])
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.div = P.Div().set_strategy(strategy=stra_list[3])
        self.log = P.Log().set_strategy(strategy=stra_list[4])
        self.sum_cross_entropy = P.ReduceSum(keep_dims=False).set_strategy(strategy=stra_list[5])
        self.mul = P.Mul().set_strategy(strategy=stra_list[6])
        self.mul2 = P.Mul().set_strategy(strategy=stra_list[7])
        self.cast = P.Cast()
        self.reduce_mean = P.ReduceMean(keep_dims=False).set_strategy(strategy=stra_list[8])
        self.sparse = sparse
        self.reduce_max = P.ReduceMax(keep_dims=True).set_strategy(strategy=stra_list[9])
        self.sub = P.Sub().set_strategy(strategy=stra_list[10])

    def construct(self, logit, label):
        logit_max = self.reduce_max(logit, -1)
        exp = self.exp(self.sub(logit, logit_max))
        exp_sum = self.reduce_sum(exp, -1)
        softmax_result = self.div(exp, exp_sum)
        if self.sparse:
            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
        softmax_result_log = self.log(softmax_result)
        loss = self.sum_cross_entropy((self.mul(softmax_result_log, label)), -1)
        loss = self.mul2(F.scalar_to_array(-1.0), loss)
        loss = self.reduce_mean(loss, -1)
        return loss
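
    # Numerically this is the standard cross entropy
    #     loss = mean_b(-sum_c(onehot(label)[b, c] * log(softmax(logit)[b, c])))
    # with max-subtraction for a stable softmax; because each intermediate
    # operator carries its own strategy, the whole expression can be sharded
    # along the class dimension.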


class MatmulNet(Cell):
    def __init__(self, matmul_stra=None, loss_stra_list=None):
        super(MatmulNet, self).__init__()
        self.matmul = P.MatMul(transpose_b=True).set_strategy(strategy=matmul_stra)
        self.loss = SoftmaxCrossEntropyExpand(sparse=True, stra_list=loss_stra_list)
        self.weight = Parameter(Tensor(np.ones(MatmulParamShape), dtype=ms.float32), name="weight")

    def construct(self, x, label):
        loss_input = self.matmul(x, self.weight)
        out = self.loss(loss_input, label)
        return out
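
    # Shapes: x is (batch, embed) and weight is (classes, embed), so
    # MatMul(transpose_b=True) yields (batch, classes) logits that feed
    # the expanded cross entropy above.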


class LossFactory():
    def __init__(self):
        dataGen = DataGenerator()
        self.input_full, self.input_part = dataGen.input_data((batch_size, embed))
        # class ids must lie in [0, classes), the MatMul output width
        self.label_full, self.label_part = dataGen.label_data((batch_size,), classes)

    def single_matmul_trains(self):
        single_callback = ModelCallback()
        net = MatmulNet()
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        model = Model(net, optimizer=optimizer)
        epoch_size = 6
        dataset = Dataset(self.input_full, self.label_full)
        model.train(epoch_size, dataset, callbacks=single_callback, dataset_sink_mode=False)
        loss_value = np.array(single_callback.loss_list)
        return loss_value

    def data_parallel_matmul_trains(self):
        parallel_callback = ModelCallback()
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
        net = MatmulNet()
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        model = Model(net, optimizer=optimizer)
        epoch_size = 6
        dataset = Dataset(self.input_part, self.label_part)
        model.train(epoch_size, dataset, callbacks=parallel_callback, dataset_sink_mode=False)
        loss_value = np.array(parallel_callback.loss_list)
        return loss_value

    def model_parallel_matmul_trains(self):
        parallel_callback = ModelCallback()
        matmul_stra = ((1, 1), (device_num, 1))
        reduce_max_stra = ((1, device_num),)
        sub_stra = ((1, device_num), (1, 1))
        exp_stra = ((1, device_num),)
        reduce_sum_stra = ((1, device_num),)
        div_stra = ((1, device_num), (1, 1))
        log_stra = ((1, device_num),)
        mul_stra = ((1, device_num), (1, device_num))
        sum_cross_entropy_stra = ((1, device_num),)
        mul2_stra = ((), (device_num,))
        reduce_mean_stra = ((device_num,),)
        onehot_stra = ((1, device_num), (), ())
        loss_stra_list = [exp_stra, reduce_sum_stra, onehot_stra, div_stra, log_stra,
                          sum_cross_entropy_stra, mul_stra, mul2_stra, reduce_mean_stra, reduce_max_stra, sub_stra]
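        # Each strategy is a tuple of per-input tuples, one integer per tensor
        # dimension giving how many slices that dimension is cut into. Here the
        # weight's class dimension is split across device_num devices while the
        # batch stays whole, so the MatMul and the loss both run model-parallel
        # along the classes.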
        context.set_auto_parallel_context(parallel_mode="auto_parallel")
        net = MatmulNet(matmul_stra=matmul_stra, loss_stra_list=loss_stra_list)
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        model = Model(net, optimizer=optimizer)
        epoch_size = 6
        dataset = Dataset(self.input_part, self.label_part)
        model.train(epoch_size, dataset, callbacks=parallel_callback, dataset_sink_mode=False)
        loss_value = np.array(parallel_callback.loss_list)
        return loss_value

    def mix_parallel_matmul_trains(self):
        parallel_callback = ModelCallback()
        matmul_stra = ((device_num, 1), (1, 1))
        reduce_max_stra = ((1, device_num),)
        sub_stra = ((device_num, 1), (device_num, 1))
        exp_stra = ((1, device_num),)
        reduce_sum_stra = ((1, device_num),)
        div_stra = ((1, device_num), (1, 1))
        log_stra = ((1, device_num),)
        mul_stra = ((1, device_num), (1, device_num))
        sum_cross_entropy_stra = ((1, device_num),)
        mul2_stra = ((), (device_num,))
        reduce_mean_stra = ((device_num,),)
        onehot_stra = ((1, device_num), (), ())
        loss_stra_list = [exp_stra, reduce_sum_stra, onehot_stra, div_stra, log_stra,
                          sum_cross_entropy_stra, mul_stra, mul2_stra, reduce_mean_stra, reduce_max_stra, sub_stra]
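        # Unlike the model-parallel case above, matmul_stra here splits the
        # input's batch dimension and leaves the weight whole, so the MatMul
        # runs data-parallel while the loss operators stay sharded along the
        # class dimension; the framework inserts the tensor redistribution
        # between the two layouts automatically.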
        context.set_auto_parallel_context(parallel_mode="auto_parallel")
        net = MatmulNet(matmul_stra=matmul_stra, loss_stra_list=loss_stra_list)
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        model = Model(net, optimizer=optimizer)
        epoch_size = 6
        dataset = Dataset(self.input_part, self.label_part)
        model.train(epoch_size, dataset, callbacks=parallel_callback, dataset_sink_mode=False)
        loss_value = np.array(parallel_callback.loss_list)
        return loss_value


def test_all_trains():
    loss_factory = LossFactory()
    context.reset_auto_parallel_context()
    single_loss = loss_factory.single_matmul_trains()
    model_parallel_loss = loss_factory.model_parallel_matmul_trains()
    mix_parallel_loss = loss_factory.mix_parallel_matmul_trains()
    assert np.allclose(single_loss, model_parallel_loss)
    assert np.allclose(single_loss, mix_parallel_loss)
